In this notebook I will look at the prior and posterior predictive generative models for a Network FKPP model.
First, document the environment and load the necessary packages.
using Pkg
Pkg.status()
Status `~/Projects/NetworkTopology/Project.toml`
  [b5ca4192] AdvancedVI v0.1.1
  [76274a88] Bijectors v0.8.14
  [a93c6f00] DataFrames v0.22.5
  [2b5f629d] DiffEqBase v6.57.5 `https://github.com/SciML/DiffEqBase.jl.git#sensitivity_interpolation`
  [41bf760c] DiffEqSensitivity v6.42.0
  [0c46a032] DifferentialEquations v6.16.0
  [31c24e10] Distributions v0.24.15
  [ced4e74d] DistributionsAD v0.6.20
  [f6369f11] ForwardDiff v0.10.16
  [7073ff75] IJulia v1.23.2
  [093fc24a] LightGraphs v1.3.5
  [c7f686f2] MCMCChains v4.7.0
  [91a5bcdd] Plots v1.10.6
  [c3e4b0f8] Pluto v0.12.21
  [37e2e3b7] ReverseDiff v1.7.0
  [f3b207a7] StatsPlots v0.14.19
  [fce5fe82] Turing v0.15.10
  [e88e6eb3] Zygote v0.6.3
using Random, DifferentialEquations, Turing, Plots, StatsPlots, MCMCChains, LightGraphs, LinearAlgebra, Base.Threads
plotly()
Turing.setadbackend(:forwarddiff);
Random.seed!(1);
As with the diffusion model, the first step will be to construct a random network using LightGraphs.
After this, we can define the Network FKPP equation:
$\frac{d\mathbf{p}_i}{dt} = -\kappa \sum\limits_{j=1}^{N}\mathbf{L}_{ij}\mathbf{p}_j + \alpha \mathbf{p}_i\left(1-\mathbf{p}_i\right)$
function make_graph(N::Int64, P::Float64)
    G = erdos_renyi(N, P)       # random graph on N nodes, edge probability P
    L = laplacian_matrix(G)     # combinatorial graph Laplacian, L = D - A
    A = adjacency_matrix(G)
    return L, A
end
make_graph (generic function with 1 method)
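The Laplacian returned by `laplacian_matrix` is the combinatorial one, L = D − A, where D is the diagonal degree matrix. A minimal stdlib-only sketch (hand-built 4-node path graph, no LightGraphs) showing the construction and its key property:

```julia
# Sketch (assumed 4-node path graph 1-2-3-4): the combinatorial Laplacian
# is L = D - A, with node degrees on the diagonal of D.
using LinearAlgebra

A = [0 1 0 0;
     1 0 1 0;
     0 1 0 1;
     0 0 1 0]                        # adjacency matrix

D = Diagonal(vec(sum(A, dims=2)))    # degree matrix
L = D - A                            # graph Laplacian

# Every row of L sums to zero, so constant vectors lie in its null space.
@assert all(sum(L, dims=2) .== 0)
```

This zero-row-sum property is what makes the diffusion term conserve mass on the network.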
function NetworkFKPP(u, p, t)
    κ, α = p
    # note: L is the global graph Laplacian defined by make_graph
    du = -κ * L * u .+ α .* u .* (1 .- u)
end
NetworkFKPP (generic function with 1 method)
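Because each row of the Laplacian sums to zero, the uniform states u ≡ 0 and u ≡ 1 are fixed points of this right-hand side: the diffusion term vanishes for constant u, and the logistic factor u(1 − u) is zero at both values. A minimal sketch with a hand-built 3-node path Laplacian (the helper name `fkpp_rhs` is illustrative, not from the notebook):

```julia
# Sketch: the FKPP right-hand side on a toy 3-node path-graph Laplacian.
using LinearAlgebra

Lpath = [ 1 -1  0;
         -1  2 -1;
          0 -1  1]

fkpp_rhs(u, κ, α) = -κ * Lpath * u .+ α .* u .* (1 .- u)

@assert all(fkpp_rhs(zeros(3), 1.5, 3.0) .== 0)   # u ≡ 0 is a fixed point
@assert all(fkpp_rhs(ones(3), 1.5, 3.0) .== 0)    # u ≡ 1 is a fixed point
```

Any other initial condition in (0, 1) is pushed towards the saturated state u ≡ 1 by the logistic growth term.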
We can now solve the model on a network with five nodes, using DifferentialEquations with an adaptive step-size numerical method.
We start by making an Erdős–Rényi random graph with connection probability 0.5.
N = 5
P = 0.5
L, A = make_graph(N, P);
Then we set the initial conditions, construct the ODEProblem, and solve numerically.
u0 = rand(N)
p = 1.5, 3          # (κ, α): diffusion and growth constants
t_span = (0.0, 2.0)
(0.0, 2.0)
problem = ODEProblem(NetworkFKPP, u0, t_span, p);
sol = solve(problem, AutoTsit5(Rosenbrock23()), saveat=0.05);
Next, we'll discretise this solution and add noise to simulate synthetic data.
data = clamp.(Array(sol) + 0.02 * randn(size(Array(sol))), 0.0,1.0);
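Gaussian observation noise can push concentrations outside the valid range [0, 1], which is why the data are clamped node-wise. A small sketch of the same pattern (the `truth` array is a stand-in for `Array(sol)`):

```julia
# Sketch: additive Gaussian noise, then clamp back into [0, 1].
using Random
Random.seed!(1)

truth = [0.0 0.5 1.0;
         0.2 0.7 0.99]                    # stand-in for Array(sol)
noisy = truth .+ 0.02 .* randn(size(truth))
data  = clamp.(noisy, 0.0, 1.0)

@assert all(0.0 .<= data .<= 1.0)
```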
plot(Array(sol)')
scatter!(data')
Now that we have a model and some synthetic data, we can create a generative model in Turing that aims to capture the data probabilistically. Our probabilistic representation of the model parameters is given by:
$$\sigma \sim \Gamma^{-1}(2, 3)$$ $$\kappa \sim \mathcal{N}(5, 10), \ \text{truncated to} \ [0, 10]$$ $$\alpha \sim \mathcal{N}(5, 10), \ \text{truncated to} \ [0, 10]$$ $$\mathbf{u}_0 \sim \mathcal{N}(0.5, 2), \ \text{truncated to} \ [0, 1]$$
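The truncated normal priors on κ and α can be pictured with a stdlib-only rejection sampler (a sketch of the distribution's shape, not Turing's actual `truncated` implementation): draw from N(5, 10) and keep only samples inside [0, 10]. Because σ = 10 is much wider than the interval, the resulting prior is nearly flat over [0, 10].

```julia
# Sketch: rejection sampling from truncated(Normal(5, 10), 0, 10).
using Random, Statistics
Random.seed!(1)

function rand_truncated_normal(μ, σ, lo, hi)
    while true
        x = μ + σ * randn()
        lo <= x <= hi && return x    # keep only draws inside [lo, hi]
    end
end

draws = [rand_truncated_normal(5.0, 10.0, 0.0, 10.0) for _ in 1:10_000]
@assert all(0.0 .<= draws .<= 10.0)
println(mean(draws))    # ≈ 5 by symmetry of the interval about the mean
```

This matches the prior-chain summaries below, where k and a have means near 5 and standard deviations near 2.9.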
@model function fitode(data, problem)
    σ ~ InverseGamma(2, 3)
    k ~ truncated(Normal(5, 10.0), 0.0, 10)
    a ~ truncated(Normal(5, 10.0), 0.0, 10)
    u ~ filldist(truncated(Normal(0.5, 2.0), 0.0, 1.0), 5)

    p = [k, a]
    prob = remake(problem, u0=u, p=p)
    predicted = solve(prob, AutoTsit5(Rosenbrock23()), saveat=0.05)

    for i ∈ 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ)
    end
end
fitode (generic function with 1 method)
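The observation model `data[:, i] ~ MvNormal(predicted[i], σ)` with a scalar σ is an isotropic Gaussian, so its log-density is just a sum of independent univariate normal log-densities, one per node. A sketch of that equivalence on toy numbers (`normal_logpdf` is an illustrative helper, not a library function):

```julia
# Sketch: isotropic Gaussian likelihood = sum of univariate normal log-pdfs.
normal_logpdf(x, μ, σ) = -0.5 * log(2π * σ^2) - (x - μ)^2 / (2σ^2)

predicted_i = [0.2, 0.5, 0.8]    # model prediction at one save point
data_i      = [0.22, 0.48, 0.79] # noisy observation
σ = 0.02

ll = sum(normal_logpdf.(data_i, predicted_i, σ))
println(ll)    # large positive: the data sit well within the noise scale
```

Summing this quantity over all save points gives the total log-likelihood that NUTS uses (together with the priors) to explore the posterior.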
prior_chain = sample(fitode(data,problem), Prior(), 10_000)
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:05
Chains MCMC chain (10000×9×1 Array{Float64,3}):
Iterations = 1:10000
Thinning interval = 1
Chains = 1
Samples per chain = 10000
parameters = a, k, u[1], u[2], u[3], u[4], u[5], σ
internals = lp
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
a 4.9929 2.8288 0.0283 0.0293 10027.0727 1.0000
k 5.0025 2.8452 0.0285 0.0253 9971.3402 1.0000
u[1] 0.4965 0.2877 0.0029 0.0031 10005.5924 0.9999
u[2] 0.5054 0.2889 0.0029 0.0034 9870.0291 0.9999
u[3] 0.4974 0.2867 0.0029 0.0032 9270.1321 1.0000
u[4] 0.4986 0.2871 0.0029 0.0028 10469.1632 0.9999
u[5] 0.5019 0.2869 0.0029 0.0031 9234.2871 0.9999
σ 2.9458 4.2150 0.0422 0.0372 9945.9505 0.9999
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
a 0.2656 2.6167 4.9730 7.3983 9.7417
k 0.2584 2.5587 4.9647 7.4672 9.7151
u[1] 0.0290 0.2474 0.4959 0.7446 0.9774
u[2] 0.0243 0.2573 0.5074 0.7563 0.9768
u[3] 0.0242 0.2475 0.5041 0.7442 0.9715
u[4] 0.0270 0.2505 0.4930 0.7452 0.9750
u[5] 0.0252 0.2548 0.5064 0.7489 0.9725
σ 0.5376 1.1168 1.7917 3.1484 12.9017
chain_array = Array(prior_chain);
function plot_priorpredictive(node, sol, data, chain)
    plot(Array(sol)[node, :], w=2, legend=false)
    for k in 1:500
        # Chain-array columns are ordered alphabetically: a, k, u[1]-u[5], σ,
        # so the ODE parameter vector (κ, α) is par[[2, 1]].
        par = chain[rand(1:size(chain, 1)), 1:8]
        resol = solve(remake(problem, u0=par[3:7], p=par[[2, 1]]),
                      AutoTsit5(Rosenbrock23()), saveat=0.05)
        plot!(Array(resol)[node, :], alpha=0.5, color="#BBBBBB", legend=false)
    end
    return scatter!(data[node, :], legend=false)
end
plot_priorpredictive (generic function with 1 method)
plot_priorpredictive(1, sol, data, chain_array)
plot_priorpredictive(2, sol, data, chain_array)
plot_priorpredictive(3, sol, data, chain_array)
plot_priorpredictive(4, sol, data, chain_array)
plot_priorpredictive(5, sol, data, chain_array)
Matrix(A)
5×5 Array{Int64,2}:
0 0 0 0 0
0 0 1 1 0
0 1 0 0 1
0 1 0 0 0
0 0 1 0 0
heatmap(Matrix(L))
The above plots show that the prior may not be well suited to our data. For the most part, the prior predictive simulations are shifted to the left, saturating earlier than the data. This is likely due to the priors placed on the diffusion and growth constants, i.e.:
$$\kappa \sim \mathcal{N}(5, 10), \ \text{truncated to} \ [0, 10]$$ $$\alpha \sim \mathcal{N}(5, 10), \ \text{truncated to} \ [0, 10]$$ Let's plot this distribution to see how wide it is.
model_prior = truncated(Normal(5, 10), 0, 10)
plot(model_prior, ylims=(0.0,0.2))
The distribution is relatively flat around the mean and between the bounds, as desired. The results above are almost certainly due to the true values not being close to the median of the distribution; there is therefore a greater chance of sampling higher values, which shifts the dynamics to the left.
Next, let's try sampling from the posterior using NUTS to see if the generative model is able to capture the dynamics of the data given our model.
model = fitode(data,problem);
chain = sample(model, NUTS(0.65), MCMCThreads(), 1_000, 10, progress=true)
┌ Info: Found initial step size
│ ϵ = 0.2
└ @ Turing.Inference /home/chaggar/.julia/packages/Turing/XLLTf/src/inference/hmc.jl:188
┌ Warning: The current proposal will be rejected due to numerical error(s).
│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC /home/chaggar/.julia/packages/AdvancedHMC/MIxdK/src/hamiltonian.jl:47
Sampling (10 threads): 100%|████████████████████████████| Time: 0:00:11
Chains MCMC chain (1000×20×10 Array{Float64,3}):
Iterations = 1:1000
Thinning interval = 1
Chains = 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
Samples per chain = 1000
parameters = a, k, u[1], u[2], u[3], u[4], u[5], σ
internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
a 2.9780 0.0416 0.0004 0.0005 6304.0784 1.0013
k 1.6177 0.1179 0.0012 0.0017 5206.8462 0.9998
u[1] 0.1890 0.0055 0.0001 0.0001 7004.9200 1.0011
u[2] 0.0523 0.0144 0.0001 0.0002 5894.6809 1.0003
u[3] 0.1233 0.0134 0.0001 0.0002 5962.3108 1.0005
u[4] 0.5161 0.0146 0.0001 0.0002 5317.4219 1.0009
u[5] 0.4103 0.0128 0.0001 0.0002 5680.0437 0.9997
σ 0.0276 0.0016 0.0000 0.0000 7885.0297 1.0002
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
a 2.8991 2.9498 2.9777 3.0058 3.0615
k 1.3943 1.5380 1.6145 1.6949 1.8580
u[1] 0.1784 0.1852 0.1890 0.1927 0.2000
u[2] 0.0233 0.0427 0.0526 0.0621 0.0801
u[3] 0.0964 0.1144 0.1232 0.1324 0.1496
u[4] 0.4879 0.5060 0.5161 0.5260 0.5445
u[5] 0.3857 0.4016 0.4101 0.4186 0.4362
σ 0.0246 0.0265 0.0276 0.0287 0.0310
plot(chain)
posterior_chain_array = Array(chain);
plot_priorpredictive(1, sol, data, posterior_chain_array)
plot_priorpredictive(2, sol, data, posterior_chain_array)
plot_priorpredictive(3, sol, data, posterior_chain_array)
plot_priorpredictive(4, sol, data, posterior_chain_array)
plot_priorpredictive(5, sol, data, posterior_chain_array)
means = mean(posterior_chain_array, dims=1)
1×8 Array{Float64,2}:
2.97797 1.61769 0.189002 0.0523408 … 0.516074 0.410305 0.0276413
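As with the predictive plots, indexing into the chain array relies on the alphabetical column ordering a, k, u[1]…u[5], σ: the posterior-mean initial condition is `means[3:7]`, and the ODE parameter vector (κ, α) is `(means[2], means[1])`. A sketch on a toy two-row stand-in for the chain array:

```julia
# Sketch: extracting u0 and (κ, α) from posterior means, assuming the
# alphabetical column order a, k, u[1], ..., u[5], σ.
using Statistics

posterior = [3.0 1.6 0.19 0.05 0.12 0.52 0.41 0.028;   # toy 2-sample "chain"
             2.9 1.7 0.18 0.06 0.13 0.51 0.42 0.027]
means = mean(posterior, dims=1)

u0_hat = means[3:7]               # linear indexing into the 1×8 matrix
p_hat  = [means[2], means[1]]     # (κ, α): column 2 is k, column 1 is a

@assert length(u0_hat) == 5
```

Note that `means[:3:7]` is not valid indexing for this; `means[3:7]` (linear indexing) is what pulls out the five u components.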
prob = remake(problem, u0=means[3:7], p=[1.5, 3])   # posterior-mean u0, true (κ, α)
predicted = solve(prob, Tsit5(), saveat=0.05)
retcode: Success
Interpolation: 1st order linear
t: 41-element Array{Float64,1}:
0.0
0.05
0.1
0.15
0.2
0.25
0.3
0.35
0.4
0.45
0.5
0.55
0.6
⋮
1.45
1.5
1.55
1.6
1.65
1.7
1.75
1.8
1.85
1.9
1.95
2.0
u: 41-element Array{Array{Float64,1},1}:
[0.18900191606653005, 0.052340765589763985, 0.12329241909098175, 0.5160737575060762, 0.41030528421763174]
[0.2130716811254749, 0.10078849396487977, 0.1574145869167381, 0.5204087984099857, 0.42595006929387946]
[0.23930227023953193, 0.15086894442350388, 0.19477638838334344, 0.5279559814890817, 0.44340180462163725]
[0.26766371227466745, 0.20224175729215324, 0.23504467042777694, 0.5385505110699668, 0.4626985068867582]
[0.29806937656046295, 0.25448464207512433, 0.27778049938636934, 0.5519888866679504, 0.4838184760060658]
[0.33037082316447863, 0.3071072905950899, 0.3224488508145135, 0.5680228658362384, 0.5066737677499422]
[0.3643557075090992, 0.35957914893001724, 0.3684426877981422, 0.5863569454533702, 0.5311078265892981]
[0.3997501836639185, 0.4113588451651066, 0.41510394087374203, 0.6066479308043575, 0.556899841265691]
[0.43622215211161364, 0.46190837604603996, 0.4617632308698226, 0.6285175018119933, 0.5837654396200448]
[0.4733978236606727, 0.5107362073367101, 0.5077776137432601, 0.6515636551127736, 0.6113778992542499]
[0.5108718226982703, 0.5574220045918541, 0.5525442612436104, 0.6753641700970929, 0.6393807981123627]
[0.5482231200321518, 0.6016134219742375, 0.5955453394289276, 0.6995062851659065, 0.6674013994835066]
[0.5850375792697491, 0.6430326459980766, 0.6363669163563608, 0.7236085709546498, 0.6950764457439947]
⋮
[0.9475300035348915, 0.9635715861372693, 0.9627695599145453, 0.9650166763110463, 0.9630631140179903]
[0.9545143492141359, 0.9685373294647065, 0.9678685271702234, 0.9696514141716945, 0.9680283809661874]
[0.9606040491028847, 0.9728319516664277, 0.9722755003785938, 0.9736983056984442, 0.972350683849285]
[0.9659011216281838, 0.9765389598957516, 0.9760770578945825, 0.9772260230539705, 0.976106710358415]
[0.9705029938271214, 0.9797400674124231, 0.9793573498387919, 0.9802956274918833, 0.979365651871566]
[0.9745024854820343, 0.9825151529947761, 0.9821980607336962, 0.9829605879723431, 0.9821892159468205]
[0.9779787470071619, 0.9849242506558737, 0.9846615021869285, 0.9852726138735929, 0.984634719520544]
[0.9809917331551533, 0.9870057680294403, 0.986788286999053, 0.9872776252928945, 0.9867506938660786]
[0.9835975155357509, 0.9888001013341654, 0.9886202099922743, 0.9890137601426363, 0.9885784542601568]
[0.9858481490487405, 0.9903450904600708, 0.9901963503266282, 0.9905156278865888, 0.9901556779393509]
[0.9877916718839517, 0.9916760189685101, 0.9915530714999823, 0.9918143095399051, 0.9915164041000731]
[0.9894721055212574, 0.9928256140921686, 0.9927240213480354, 0.9929373576691208, 0.9926910338985752]
plot(Array(predicted)')
scatter!(data')